package top.lingkang.mm.orm;

import lombok.extern.slf4j.Slf4j;
import org.apache.ibatis.binding.MapperMethod;
import org.apache.ibatis.executor.parameter.ParameterHandler;
import org.apache.ibatis.mapping.*;
import org.apache.ibatis.parsing.XNode;
import org.apache.ibatis.scripting.LanguageDriver;
import org.apache.ibatis.scripting.xmltags.DynamicSqlSource;
import org.apache.ibatis.scripting.xmltags.TextSqlNode;
import org.apache.ibatis.scripting.xmltags.XMLLanguageDriver;
import org.apache.ibatis.session.Configuration;
import top.lingkang.mm.annotation.Table;
import top.lingkang.mm.constant.IdType;
import top.lingkang.mm.error.MagicException;
import top.lingkang.mm.utils.MagicUtils;

import java.lang.reflect.Field;
import java.lang.reflect.ParameterizedType;
import java.lang.reflect.Type;
import java.util.*;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentMap;

/**
 * MyBatis {@link LanguageDriver} that expands {@code BaseMapper} placeholder scripts into CRUD SQL for the
 * mapper's entity class and prepares the parameters of those statements before they are bound.
 * Scripts that are not BaseMapper placeholders are delegated unchanged to {@link XMLLanguageDriver}.
 *
 * @author lingkang
 * Created by 2024/3/11
 */
@Slf4j
public class BaseMapperDriver extends XMLLanguageDriver implements LanguageDriver {
    // Fully-qualified name of the generic base interface; its single type argument names the entity class.
    private static final String BASE_MAPPER_NAME = "top.lingkang.mm.orm.BaseMapper";
    // loadedResources entries that name an interface look like "interface a.b.UserMapper".
    private static final String INTERFACE_RESOURCE_PREFIX = "interface ";

    // Reflective handles on private MyBatis fields that this driver reads or overwrites.
    private final Field loadedResources, keyProperties, keyColumns, resultClass;
    // "<interface resource>-<script>" ids whose SqlSource this driver instance has already generated.
    private final Set<String> existsMapperMethod = new HashSet<>();
    // Parameter key under which the update time is stored, e.g. current_1732998718664.
    private static final String curr = "current_" + new Date().getTime();
    // Property path of that value inside the q2 parameter, e.g. q2.param.current_1732998718664.
    private static final String currentTime = BaseMapperSql.param_q2 + ".param." + curr;
    // Generated SqlSource per sqlId; static, so shared by every driver instance.
    private static final ConcurrentMap<String, SqlSource> cacheSqlSource = new ConcurrentHashMap<>();

    /**
     * Resolves and opens the private MyBatis fields this driver manipulates.
     *
     * @throws RuntimeException if the MyBatis version on the classpath lacks one of those fields
     */
    public BaseMapperDriver() {
        try {
            loadedResources = Configuration.class.getDeclaredField("loadedResources");
            loadedResources.setAccessible(true);
            keyProperties = MappedStatement.class.getDeclaredField("keyProperties");
            keyProperties.setAccessible(true);
            keyColumns = MappedStatement.class.getDeclaredField("keyColumns");
            keyColumns.setAccessible(true);
            resultClass = ResultMap.class.getDeclaredField("type");
            resultClass.setAccessible(true);
        } catch (NoSuchFieldException e) {
            throw new RuntimeException(e);
        }
    }

    @Override
    public SqlSource createSqlSource(Configuration configuration, XNode script, Class<?> parameterType) {
        return super.createSqlSource(configuration, script, parameterType);
    }

    /**
     * Creates the SqlSource for an annotation script. Non-BaseMapper scripts go to the XML driver; BaseMapper
     * placeholders are expanded into SQL for the entity of the mapper interface being parsed.
     *
     * @throws MagicException if no loaded mapper interface extends BaseMapper, or interface inspection fails
     */
    @Override
    public SqlSource createSqlSource(Configuration configuration, String script, Class<?> parameterType) {
        if (!BaseMapperSql.isBaseMapperSql(script))
            return super.createSqlSource(configuration, script, parameterType);

        List<BaseInterface> interfaces = loadedBaseMapperInterface(configuration);
        if (interfaces.isEmpty())
            throw new MagicException("未找到继承 BaseMapper 的 Mapper 接口，无法生成 SQL: " + script);

        BaseInterface interface_ = null;
        String sqlId = null;
        if (interfaces.size() == 1) {
            interface_ = interfaces.get(0);
            sqlId = interface_.getInterfaceStr() + "-" + script;
        } else {
            // The driver is not told which mapper is being parsed: pick the first loaded BaseMapper
            // interface that has not generated this script yet.
            for (BaseInterface baseInterface : interfaces) {
                sqlId = baseInterface.getInterfaceStr() + "-" + script;
                if (!existsMapperMethod.contains(sqlId)) {
                    interface_ = baseInterface;
                    break;
                }
            }
            if (interface_ == null) {
                // Every candidate already generated this script: reuse the cached source of the last one checked.
                return cacheSqlSource.get(sqlId);
            }
        }

        SqlSource sqlSource = buildBaseMapperSqlSource(configuration, script, parameterType, interface_);
        existsMapperMethod.add(sqlId);
        cacheSqlSource.put(sqlId, sqlSource);// cache every generated source, including updateByColumn
        return sqlSource;
    }

    /**
     * Expands one BaseMapper placeholder script into SQL for the interface's entity. For entities without
     * {@code @Id}, id-based scripts become a "-1..." text; the select/delete parameter handling rejects SQL
     * starting with "-1", and updateById is rejected by the @Id check on update.
     */
    private SqlSource buildBaseMapperSqlSource(Configuration configuration, String script, Class<?> parameterType,
                                               BaseInterface interface_) {
        MagicEntity entity = MagicEntityUtils.getMagicEntity(interface_.getEntityClass());
        String sql = script;
        switch (script) {
            case BaseMapperSql.selectAll:
                sql = entity.getSelectTableSql();
                break;
            case BaseMapperSql.createQuery:
            case BaseMapperSql.selectByQuery:
            case BaseMapperSql.selectByQueryOne:
                sql = entity.getSelectTableSql() + " ${q.sql}";
                break;
            case BaseMapperSql.selectById:
                if (entity.getIdIndex() == -1) {
                    sql = "-1: " + interface_.getInter().getName() + ".selectById 查询的实体类没有 @Id 注解: " + entity.getClazz().getName() + " 查询失败";
                } else {
                    sql = entity.getSelectTableSql() + " where "
                            + entity.getColumnName().get(entity.getIdIndex()) + "=#{" + BaseMapperSql.param_id + "}";
                }
                break;
            case BaseMapperSql.selectInIds:
                if (entity.getIdIndex() == -1) {
                    sql = "-1: " + interface_.getInter().getName() + ".selectInIds 查询的实体类没有 @Id 注解: " + entity.getClazz().getName() + " 查询失败";
                } else {
                    sql = "<script>" + entity.getSelectTableSql() + " where "
                            + entity.getColumnName().get(entity.getIdIndex()) + " in ("
                            + "<foreach collection=\"" + BaseMapperSql.param_id_3
                            + "\" item=\"e\" separator=\",\">#{e}</foreach>)</script>";
                }
                break;
            case BaseMapperSql.existsById:
                if (entity.getIdIndex() == -1) {
                    sql = "-1: " + interface_.getInter().getName() + ".existsById 实体类没有 @Id 注解: " + entity.getClazz().getName() + " 查询失败";
                } else {
                    sql = "select count(*) from " + entity.getTableName() + " where "
                            + entity.getColumnName().get(entity.getIdIndex()) + "=#{" + BaseMapperSql.param_id + "}";
                }
                break;
            case BaseMapperSql.existsByEntity:
                if (entity.getIdIndex() == -1) {
                    sql = "-1: " + interface_.getInter().getName() + ".existsByEntity 实体类没有 @Id 注解: " + entity.getClazz().getName() + " 查询失败";
                } else {
                    sql = "select count(*) from " + entity.getTableName() + " where "
                            + entity.getColumnName().get(entity.getIdIndex()) + "=#{" + BaseMapperSql.param_id_2 + "}";
                }
                break;
            case BaseMapperSql.existsByQuery:
            case BaseMapperSql.selectCountByQuery:
                sql = "select count(*) from " + entity.getTableName() + " ${q.sql}";
                break;
            case BaseMapperSql.selectCount:
                sql = "select count(*) from " + entity.getTableName();
                break;
            case BaseMapperSql.selectColumn:
            case BaseMapperSql.selectColumnOne:
                sql = "select ${q2.columns} from " + entity.getTableName() + " ${q2.sql}";
                break;
            case BaseMapperSql.insert:
                sql = "insert into " + entity.getTableName() + "(" + MagicEntityUtils.getColumns(entity.getColumnName(), null)
                        + ") values(" + MagicEntityUtils.getInsertValues(entity, BaseMapperSql.magic_base_e) + ")";
                break;
            case BaseMapperSql.insertBatch:
                sql = "<script>insert into " + entity.getTableName() + "(" + MagicEntityUtils.getColumns(entity.getColumnName(), null) + ") values"
                        + " <foreach collection=\"" + BaseMapperSql.magic_base_list + "\" index=\"\" item=\"e\""
                        + " separator=\",\">(" + MagicEntityUtils.getInsertPrefixValues(entity, "e") + ")</foreach></script>";
                break;
            case BaseMapperSql.updateById:
                if (entity.getIdIndex() != -1) {
                    sql = "update " + entity.getTableName() + " set " + MagicEntityUtils.getSetColumns(entity, BaseMapperSql.magic_base_e)
                            + " where " + entity.getColumnName().get(entity.getIdIndex()) + "=#{" + BaseMapperSql.magic_base_e + "."
                            + entity.getFields().get(entity.getIdIndex()).getName() + "}";
                } else {
                    sql = "-1";
                }
                break;
            case BaseMapperSql.updateByQuery:
                sql = "update " + entity.getTableName() + " set " + MagicEntityUtils.getSetColumns(entity, BaseMapperSql.magic_base_e)
                        + " ${q.sql}";
                break;
            case BaseMapperSql.updateByColumn:
                if (entity.getAutoUpdateTimeColumn() == null) {
                    sql = "update " + entity.getTableName() + " set ${q2.sql}";
                    break;
                }
                return buildUpdateByColumnWithTime(configuration, entity);
            case BaseMapperSql.deleteById:
                if (entity.getIdIndex() == -1) {
                    sql = "-1: " + interface_.getInter().getName() + ".deleteById 实体类没有 @Id 注解: " + entity.getClazz().getName() + " 删除失败";
                } else {
                    sql = "delete from " + entity.getTableName() + " where "
                            + entity.getColumnName().get(entity.getIdIndex()) + "=#{" + BaseMapperSql.param_id + "}";
                }
                break;
            case BaseMapperSql.deleteInIds:
                if (entity.getIdIndex() == -1) {
                    sql = "-1: " + interface_.getInter().getName() + ".deleteInIds 实体类没有 @Id 注解: " + entity.getClazz().getName() + " 删除失败";
                } else {
                    sql = "<script>delete from " + entity.getTableName() + " where "
                            + entity.getColumnName().get(entity.getIdIndex()) + " in ("
                            + "<foreach collection=\"" + BaseMapperSql.param_id_2
                            + "\" item=\"e\" separator=\",\">#{e}</foreach>)</script>";
                }
                break;
            case BaseMapperSql.deleteByQuery:
                sql = "delete from " + entity.getTableName() + " ${q.sql}";
                break;
        }
        return super.createSqlSource(configuration, sql, parameterType);
    }

    /**
     * Builds "update table set col1=#{currentTime},...,${q2.sql}" (v1.1.0+). Every auto-update time column
     * is bound to the value that the update parameter handling stores under {@link #curr}.
     */
    private SqlSource buildUpdateByColumnWithTime(Configuration configuration, MagicEntity entity) {
        StringBuilder sql = new StringBuilder("update ").append(entity.getTableName()).append(" set ");
        for (String column : entity.getAutoUpdateTimeColumn())
            sql.append(column).append("=#{").append(currentTime).append("},");
        sql.append("${q2.sql}");
        // NOTE(review): unlike the other scripts this bypasses super.createSqlSource (and its ${} handling of
        // configuration variables); confirm that is intended.
        return new DynamicSqlSource(configuration, new TextSqlNode(sql.toString()));
    }

    /**
     * Validates and completes BaseMapper parameters before MyBatis binds them.
     * <ul>
     * <li>SQL starting with "-1" (an unsupported id method) is rejected for select and delete.</li>
     * <li>Entity arguments are replaced by their id values.</li>
     * <li>Insert/update entities are passed through the MagicEntityUtils time/pre-update hooks.</li>
     * </ul>
     * Parameter objects that are not a {@link MapperMethod.ParamMap} are left untouched.
     *
     * @throws MagicException when a BaseMapper argument is null or the entity lacks a required {@code @Id}
     */
    @Override
    public ParameterHandler createParameterHandler(MappedStatement mappedStatement, Object parameterObject, BoundSql boundSql) {
        SqlCommandType commandType = mappedStatement.getSqlCommandType();
        if (commandType == SqlCommandType.SELECT) {
            prepareSelect(mappedStatement, parameterObject, boundSql);
        } else if (commandType == SqlCommandType.UPDATE) {
            prepareUpdate(parameterObject, boundSql);
        } else if (commandType == SqlCommandType.INSERT) {
            prepareInsert(mappedStatement, parameterObject);
        } else if (commandType == SqlCommandType.DELETE) {
            prepareDelete(parameterObject, boundSql);
        }
        return super.createParameterHandler(mappedStatement, parameterObject, boundSql);
    }

    /** Select: reject "-1" SQL, apply the selectColumn result class, and turn an entity argument into its id. */
    private void prepareSelect(MappedStatement mappedStatement, Object parameterObject, BoundSql boundSql) {
        if (boundSql.getSql().startsWith("-1"))
            throw new MagicException(boundSql.getSql());
        if (!(parameterObject instanceof MapperMethod.ParamMap))
            return;
        MapperMethod.ParamMap map = (MapperMethod.ParamMap) parameterObject;
        if (map.containsKey(BaseMapperSql.param_q2)) {
            Object object = map.get(BaseMapperSql.param_q2);
            try {
                // NOTE(review): this overwrites the type of the statement's (presumably shared) ResultMap on every
                // call, so concurrent calls with different result classes may race — confirm.
                resultClass.set(mappedStatement.getResultMaps().get(0), ((QueryColumn) object).getResultClass());
            } catch (IllegalAccessException e) {
                throw new MagicException(e);
            }
        } else if (map.containsKey(BaseMapperSql.param_id_2)) {
            Object object = map.get(BaseMapperSql.param_id_2);
            if (object == null)
                throw new MagicException("实体入参不能为空！");
            MagicEntity entity = MagicEntityUtils.getMagicEntity(object.getClass());
            Field field = entity.getFields().get(entity.getIdIndex());
            Object value = MagicUtils.getValue(field, object);
            if (value == null)
                throw new MagicException("实体入参的 id 属性不能为空：" + object);
            map.put(BaseMapperSql.param_id_2, value);
        }
    }

    /** Update: validate the entity and run its hooks, or store the current time for updateByColumn. */
    private void prepareUpdate(Object parameterObject, BoundSql boundSql) {
        if (!(parameterObject instanceof MapperMethod.ParamMap))
            return;
        MapperMethod.ParamMap map = (MapperMethod.ParamMap) parameterObject;
        if (map.containsKey(BaseMapperSql.magic_base_e)) {
            Object object = map.get(BaseMapperSql.magic_base_e);
            if (object == null)
                throw new MagicException("更新入参不能为空! ");
            MagicEntity entity = MagicEntityUtils.getMagicEntity(object.getClass());
            if (!map.containsKey(BaseMapperSql.param_q) && entity.getIdIndex() == -1)
                throw new MagicException("更新实体类没有 @Id 注解: " + entity.getClazz().getName());
            MagicEntityUtils.autoSetTime(entity, object, false);
            MagicEntityUtils.execPreUpdate(entity, object);
        } else if (map.containsKey(BaseMapperSql.param_q2)
                && !boundSql.getParameterMappings().isEmpty()
                && currentTime.equals(boundSql.getParameterMappings().get(0).getProperty())) {
            // updateByColumn with auto-update time columns: supply the value bound to #{currentTime}
            UpdateColumn updateColumn = (UpdateColumn) map.get(BaseMapperSql.param_q2);
            updateColumn.param.put(curr, new Date());
        }
    }

    /** Insert: configure or assign ids and run the entity hooks for a single entity or a batch list. */
    private void prepareInsert(MappedStatement mappedStatement, Object parameterObject) {
        if (!(parameterObject instanceof MapperMethod.ParamMap))
            return;
        MapperMethod.ParamMap map = (MapperMethod.ParamMap) parameterObject;
        try {
            if (map.containsKey(BaseMapperSql.magic_base_e)) {
                Object object = map.get(BaseMapperSql.magic_base_e);
                if (object == null)
                    throw new MagicException("插入对象不能为空！");
                MagicEntity entity = MagicEntityUtils.getMagicEntity(object.getClass());
                // Only touch ids when the entity has @Id (fixes an SQLite error when no generated id is set up).
                if (entity.getIdIndex() != -1) {
                    if (entity.getIdAnn().value() == IdType.AUTO) {
                        returnGeneratedId(mappedStatement, entity);
                    } else if (entity.getIdAnn().value() == IdType.ASSIGN) {// id generated by the application
                        MagicEntityUtils.setIdValue(object, mappedStatement.getConfiguration(), entity);
                    }
                }
                MagicEntityUtils.autoSetTime(entity, object, true);
                MagicEntityUtils.execPreUpdate(entity, object);
            } else if (map.containsKey(BaseMapperSql.magic_base_list)) {// batch insert
                Object object = map.get(BaseMapperSql.magic_base_list);
                if (object == null)
                    throw new MagicException("插入对象列表不能为空！");
                @SuppressWarnings("unchecked")
                List<Object> list = (List<Object>) object;
                if (list.isEmpty())
                    throw new MagicException("插入对象列表不能为空！");
                MagicEntity entity = MagicEntityUtils.getMagicEntity(list.get(0).getClass());
                if (entity.getIdIndex() != -1) {
                    if (entity.getIdAnn().value() == IdType.AUTO) {
                        returnGeneratedId(mappedStatement, entity);
                    } else if (entity.getIdAnn().value() == IdType.ASSIGN) {
                        for (Object element : list)
                            MagicEntityUtils.setIdValue(element, mappedStatement.getConfiguration(), entity);
                    }
                }
                MagicEntityUtils.autoSetTimeList(entity, list, true);
                MagicEntityUtils.execPreUpdateList(entity, list);
            }
        } catch (MagicException e) {
            throw e;// already descriptive; do not wrap again
        } catch (Exception e) {
            throw new MagicException(e);
        }
    }

    /**
     * Points the statement's keyProperties/keyColumns at the entity's @Id field and column, so the generated
     * id can be written back into the entity (relies on the statement's key generator — confirm).
     */
    private void returnGeneratedId(MappedStatement mappedStatement, MagicEntity entity) throws IllegalAccessException {
        keyProperties.set(mappedStatement, new String[]{entity.getFields().get(entity.getIdIndex()).getName()});
        keyColumns.set(mappedStatement, new String[]{entity.getColumnName().get(entity.getIdIndex())});
    }

    /** Delete: reject "-1" SQL and replace a {@code @Table} entity passed as the id by its id value. */
    private void prepareDelete(Object parameterObject, BoundSql boundSql) {
        if (boundSql.getSql().startsWith("-1"))
            throw new MagicException(boundSql.getSql());
        if (!(parameterObject instanceof MapperMethod.ParamMap))
            return;
        MapperMethod.ParamMap map = (MapperMethod.ParamMap) parameterObject;
        if (!map.containsKey(BaseMapperSql.param_id))
            return;
        Object object = map.get(BaseMapperSql.param_id);
        if (object == null)
            throw new MagicException("删除的id对象不能为空! ");
        if (object.getClass().getAnnotation(Table.class) == null)
            return;// a plain id value is bound as-is
        MagicEntity entity = MagicEntityUtils.getMagicEntity(object.getClass());
        if (entity.getIdIndex() == -1)
            throw new MagicException("删除的实体类没有 @Id 注解: " + entity.getClazz().getName());
        Object id;
        try {
            id = entity.getFields().get(entity.getIdIndex()).get(object);
        } catch (IllegalAccessException e) {
            throw new MagicException(e);
        }
        map.put(BaseMapperSql.param_id, id);
    }


    // ------------------------------------------------------------------------------------------------------

    /**
     * Lists the interfaces recorded in Configuration.loadedResources that directly extend BaseMapper,
     * each paired with its entity class.
     *
     * @throws MagicException wrapping any class-loading or generic-signature error
     */
    private List<BaseInterface> loadedBaseMapperInterface(Configuration configuration) {
        List<BaseInterface> result = new ArrayList<>();
        try {
            @SuppressWarnings("unchecked")
            Set<String> resources = (Set<String>) loadedResources.get(configuration);
            for (String resource : resources) {
                if (!resource.startsWith(INTERFACE_RESOURCE_PREFIX))
                    continue;
                String interfaceName = resource.substring(INTERFACE_RESOURCE_PREFIX.length());
                Class<?> mapperInterface = getClass().getClassLoader().loadClass(interfaceName);
                Class<?> entityClass = resolveEntityClass(mapperInterface);
                if (entityClass == null)
                    continue;

                BaseInterface baseInterface = new BaseInterface();
                baseInterface.setInterfaceStr(resource);
                baseInterface.setInter(mapperInterface);
                baseInterface.setEntityClass(entityClass);
                result.add(baseInterface);
            }
        } catch (Exception e) {
            throw new MagicException(e);
        }
        return result;
    }

    /**
     * Returns E for an interface declared as {@code XxxMapper extends BaseMapper<E>}, e.g. UserEntity for
     * {@code UserMapper extends BaseMapper<UserEntity>}.
     *
     * @return the entity class, or {@code null} if the interface does not directly extend BaseMapper
     * @throws IllegalStateException if BaseMapper is extended raw or with a type argument that is not a class
     */
    private Class<?> resolveEntityClass(Class<?> mapperInterface) {
        for (Type type : mapperInterface.getGenericInterfaces()) {
            Type rawType = type instanceof ParameterizedType ? ((ParameterizedType) type).getRawType() : type;
            if (!BASE_MAPPER_NAME.equals(rawType.getTypeName()))
                continue;
            if (type instanceof ParameterizedType) {
                Type argument = ((ParameterizedType) type).getActualTypeArguments()[0];
                if (argument instanceof Class)
                    return (Class<?>) argument;
            }
            throw new IllegalStateException("BaseMapper必须设置好映射泛型实体类，例如：UserMapper extends BaseMapper<UserEntity>，当前接口：" + mapperInterface.getName());
        }
        return null;
    }
}
